{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用朴素贝叶斯过滤垃圾邮件"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**说明:**\n",
    "\n",
    "将 `email` 文件夹放在当前目录下。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import re"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义创建列表函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def createVocabList(dataSet):\n",
    "    vocabSet = set([])  #create empty set\n",
    "    for document in dataSet:\n",
    "        vocabSet = vocabSet | set(document) #union of the two sets\n",
    "    return list(vocabSet)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义词集模型函数（set-of-words）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def setOfWords2Vec(vocabList, inputSet):\n",
    "    returnVec = [0]*len(vocabList)\n",
    "    for word in inputSet:\n",
    "        if word in vocabList:\n",
    "            returnVec[vocabList.index(word)] = 1\n",
    "        else: print(\"the word: %s is not in my Vocabulary!\" % word)\n",
    "    return returnVec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义词带模型函数（bag-of-words）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def bagOfWords2VecMN(vocabList, inputSet):\n",
    "    returnVec = [0]*len(vocabList)\n",
    "    for word in inputSet:\n",
    "        if word in vocabList:\n",
    "            returnVec[vocabList.index(word)] += 1\n",
    "    return returnVec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义朴素贝叶斯分类器训练函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def trainNB(trainMatrix, trainCategory):\n",
    "    numTrainDocs = len(trainMatrix)\n",
    "    numWords = len(trainMatrix[0])\n",
    "    pAbusive = sum(trainCategory)/float(numTrainDocs)\n",
    "    p0Num = np.ones(numWords); p1Num = np.ones(numWords)      #change to np.ones()\n",
    "    p0Denom = 2.0; p1Denom = 2.0                        #change to 2.0\n",
    "    for i in range(numTrainDocs):\n",
    "        if trainCategory[i] == 1:\n",
    "            p1Num += trainMatrix[i]\n",
    "            p1Denom += sum(trainMatrix[i])\n",
    "        else:\n",
    "            p0Num += trainMatrix[i]\n",
    "            p0Denom += sum(trainMatrix[i])\n",
    "    p1Vect = np.log(p1Num/p1Denom)          #change to np.log()\n",
    "    p0Vect = np.log(p0Num/p0Denom)          #change to np.log()\n",
    "    return p0Vect, p1Vect, pAbusive"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义朴素贝叶斯分类器预测函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):\n",
    "    p1 = sum(vec2Classify * p1Vec) + np.log(pClass1)    #element-wise mult\n",
    "    p0 = sum(vec2Classify * p0Vec) + np.log(1.0 - pClass1)\n",
    "    if p1 > p0:\n",
    "        return 1\n",
    "    else:\n",
    "        return 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用朴素贝叶斯过滤垃圾邮件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def textParse(bigString):    #input is big string, #output is word list\n",
    "    import re\n",
    "    listOfTokens = re.split(r'\\W+', bigString)\n",
    "    return [tok.lower() for tok in listOfTokens if len(tok) > 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def spamTest():\n",
    "    docList = []; classList = []; fullText = []\n",
    "    for i in range(1, 26):\n",
    "        wordList = textParse(open('email/spam/%d.txt' % i, encoding=\"ISO-8859-1\").read())\n",
    "        docList.append(wordList)\n",
    "        fullText.extend(wordList)\n",
    "        classList.append(1)\n",
    "        wordList = textParse(open('email/ham/%d.txt' % i, encoding=\"ISO-8859-1\").read())\n",
    "        docList.append(wordList)\n",
    "        fullText.extend(wordList)\n",
    "        classList.append(0)\n",
    "    vocabList = createVocabList(docList)#create vocabulary\n",
    "    trainingSet = range(50); testSet = []           #create test set\n",
    "    for i in range(10):\n",
    "        randIndex = int(np.random.uniform(0, len(trainingSet)))\n",
    "        testSet.append(trainingSet[randIndex])\n",
    "        del(list(trainingSet)[randIndex])\n",
    "    trainMat = []; trainClasses = []\n",
    "    for docIndex in trainingSet:#train the classifier (get probs) trainNB0\n",
    "        trainMat.append(bagOfWords2VecMN(vocabList, docList[docIndex]))\n",
    "        trainClasses.append(classList[docIndex])\n",
    "    p0V, p1V, pSpam = trainNB(np.array(trainMat), np.array(trainClasses))\n",
    "    errorCount = 0\n",
    "    for docIndex in testSet:        #classify the remaining items\n",
    "        wordVector = bagOfWords2VecMN(vocabList, docList[docIndex])\n",
    "        if classifyNB(np.array(wordVector), p0V, p1V, pSpam) != classList[docIndex]:\n",
    "            errorCount += 1\n",
    "            print(\"classification error\", docList[docIndex])\n",
    "    print('the error rate is: ', float(errorCount)/len(testSet))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the error rate is:  0.0\n"
     ]
    }
   ],
   "source": [
    "spamTest()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
